#!/usr/bin/env python3
import sys, os, zlib, struct, hashlib, subprocess
from hexdump import hexdump
from tinygrad.helpers import DEBUG, getenv, fetch
from tinygrad.runtime.support.usb import USB3

def patch(input_filepath: str, file_hash: str, patches) -> bytearray:
  """
  Apply byte patches to a firmware image in memory and refresh its trailing checksum fields.

  input_filepath: image to read; the file on disk is not modified.
  file_hash: expected hex MD5 digest (lowercase, as hashlib's hexdigest() produces) of the UNPATCHED image.
    This guards against patching the wrong file.
  patches: iterable of (offset, expected_bytes, new_bytes). expected_bytes must currently be present at offset.
    It is replaced by new_bytes, which must be the same length.

  Returns the patched image as a bytearray of the same length. data[-5] is set to the 8-bit additive sum of
  data[4:-6], and data[-4:] to the little-endian CRC32 of data[4:-6]. data[:4] and data[-6] are left as read.

  Raises ValueError on a hash mismatch, a patch length mismatch, an offset outside the image, or if the bytes
  at a patch offset are not the expected ones.
  """
  with open(input_filepath, 'rb') as infile: data = bytearray(infile.read())

  if_hash = hashlib.md5(data).hexdigest()
  if if_hash != file_hash:
    raise ValueError(f"File hash mismatch: expected {file_hash}, got {if_hash}")

  for offset, expected_bytes, new_bytes in patches:
    if len(expected_bytes) != len(new_bytes):
      raise ValueError("Expected bytes and new bytes must be the same length")
    # Raise instead of returning False: a bare False only failed later, when the caller tried to slice it.
    # Negative offsets would silently index from the end of the image, so reject them too.
    if offset < 0 or offset + len(new_bytes) > len(data):
      raise ValueError(f"Patch at offset {offset:x} (+{len(new_bytes)} bytes) is outside the {len(data)}-byte image")
    current_bytes = bytes(data[offset:offset + len(expected_bytes)])
    # A real check, not `assert`: asserts are stripped under `python -O`. That would let us overwrite
    # (and then flash) bytes we never validated.
    if current_bytes != expected_bytes:
      raise ValueError(f"Expected {expected_bytes} at offset {offset:x}, but got {current_bytes}")
    data[offset:offset + len(new_bytes)] = new_bytes

  # Trailer fields, both computed over data[4:-6]:
  # byte -5 holds an 8-bit additive checksum, and bytes -4..-1 hold a little-endian CRC32.
  body = data[4:-6]
  data[-5] = sum(body) & 0xff
  data[-4:] = zlib.crc32(body).to_bytes(4, 'little')
  return data

# Locate the stock firmware image next to this script, downloading it first if it is missing.
path = os.path.dirname(os.path.abspath(__file__))
file_hash = "5284e618d96ef804c06f47f3b73656b7"  # MD5 of the unmodified image; verified inside patch()
file_path = os.path.join(path, "Software/AS_USB4_240417_85_00_00.bin")

if not os.path.exists(file_path):
  # Fetch the vendor zip via the Internet Archive and extract only the firmware image.
  # - Argument lists (no shell) keep directories containing quotes, spaces or `$` safe.
  # - `-f` makes curl fail on HTTP errors instead of saving the error page as fw.zip.
  # - `-L` follows redirects.
  # - check=True stops the script rather than continuing with a missing or broken download.
  url = "https://web.archive.org/web/20250430124720/https://www.station-drivers.com/index.php/en/component/remository/func-download/6341/chk,3ef8b04704a18eb2fc57ff60382379ad/no_html,1/lang,en-gb/"
  zip_path = os.path.join(path, "fw.zip")
  subprocess.run(["curl", "-fL", "-o", zip_path, url], check=True)
  subprocess.run(["unzip", "-o", zip_path, "Software/AS_USB4_240417_85_00_00.bin", "-d", path], check=True)

# Single-byte patch at 0x2a0d + 1 + 4 == 0x2a12: the byte there must be 0x0a and is replaced with 0x05.
# NOTE(review): what this byte controls in the firmware is not evident from this script — TODO document.
patches = [(0x2a0d + 1 + 4, b'\x0a', b'\x05')]
patched_fw = patch(file_path, file_hash, patches)

# USBDEV selects the target as "VID:PID" in hex; the default is 174C:2464.
usbdev = getenv("USBDEV", "174C:2464")
vendor, device = [int(field, 16) for field in usbdev.split(":")]
try:
  # The four trailing ints go straight to USB3 (presumably endpoint addresses — TODO confirm against USB3).
  dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04)
except RuntimeError as err:
  # Keep the underlying error text, but point the user at the USBDEV override.
  raise RuntimeError(f'{err}. You can set USBDEV environment variable to your device\'s vendor and device ID (e.g., USBDEV="174C:2464")') from err

# First 128-byte (8 rows x 16) config blob, written below by the 0xe1 command with third CDB byte 0x0.
# Per-row comments give the ASCII the bytes decode to; 0xFF appears to be used as filler.
# NOTE(review): the field layout (which offsets are vendor/product/serial strings, what the non-ASCII tail in
# the last row means — possibly IDs/flags/a checksum) is not documented here — TODO confirm against the controller docs.
config1 = bytes([
  0xFF, 0xFF, 0xFF, 0xFF, 0x41, 0x41, 0x41, 0x41, 0x42, 0x42, 0x42, 0x42, 0x30, 0x30, 0x36, 0x30,  # "AAAA" "BBBB" "0060"
  0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x74, 0x69, 0x6E, 0x79, 0xFF, 0xFF, 0xFF, 0xFF,  # "tiny"
  0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
  0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x74, 0x69, 0x6E, 0x79,  # "tiny"
  0xFF, 0xFF, 0xFF, 0xFF, 0x55, 0x53, 0x42, 0x20, 0x33, 0x2E, 0x32, 0x20, 0x50, 0x43, 0x49, 0x65,  # "USB 3.2 PCIe"
  0x20, 0x54, 0x69, 0x6E, 0x79, 0x45, 0x6E, 0x63, 0x6C, 0x6F, 0x73, 0x75, 0x72, 0x65, 0xFF, 0xFF,  # " TinyEnclosure"
  0xFF, 0xFF, 0xFF, 0xFF, 0x54, 0x69, 0x6E, 0x79, 0x45, 0x6E, 0x63, 0x6C, 0x6F, 0x73, 0x75, 0x72,  # "TinyEnclosur"
  0x65, 0xFF, 0xFF, 0xFF, 0xD1, 0xAD, 0x01, 0x00, 0x00, 0x01, 0xCF, 0xFF, 0x02, 0xFF, 0x5A, 0x94])  # "e", then non-ASCII bytes

# Second 128-byte (8 rows x 16) config blob, written below by the 0xe1 command with third CDB byte 0x1.
# Per-row comments give the ASCII the bytes decode to; 0xFF appears to be used as filler.
# NOTE(review): the meaning of the non-ASCII rows at the end (possibly IDs/settings/a checksum) is not
# evident here — TODO confirm against the controller docs.
config2 = bytes([
  0xFF, 0xFF, 0xFF, 0xFF, 0x47, 0x6F, 0x70, 0x6F, 0x64, 0x20, 0x47, 0x72, 0x6F, 0x75, 0x70, 0x20,  # "Gopod Group "
  0x4C, 0x69, 0x6D, 0x69, 0x74, 0x65, 0x64, 0x2E, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,  # "Limited."
  0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x55, 0x53, 0x42, 0x34,  # "USB4"
  0x20, 0x4E, 0x56, 0x4D, 0x65, 0x20, 0x53, 0x53, 0x44, 0x20, 0x50, 0x72, 0x6F, 0x20, 0x45, 0x6E,  # " NVMe SSD Pro En"
  0x63, 0x6C, 0x6F, 0x73, 0x75, 0x72, 0x65, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,  # "closure"
  0xFF, 0xFF, 0xFF, 0xFF, 0x8C, 0xBF, 0xFF, 0x97, 0xC1, 0xF3, 0xFF, 0xFF, 0x01, 0x2D, 0x66, 0xD6,  # binary data
  0x66, 0x06, 0x00, 0xC0, 0x87, 0x01, 0x5A, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xCA, 0x01, 0x66, 0xD6,  # binary data
  0xE3, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, 0x01, 0x00, 0xA5, 0x67])  # binary data

# Split the patched image in two: the first 0xff00 bytes, and everything after.
# Each part is uploaded with its own 0xe3 command further down.
# NOTE(review): why the split is at 0xff00 is not evident here — presumably a per-transfer or memory-region
# limit of the controller; TODO confirm.
part1 = patched_fw[:0xff00]
part2 = patched_fw[0xff00:]

# config patch
# Write the two 128-byte config blobs, each as one command with a data payload.
# CDB layout '>BBB12x' = 0xe1, 0x50, blob index (0, then 1), then 12 zero pad bytes (15 bytes total).
cdb = struct.pack('>BBB12x', 0xe1, 0x50, 0x0)
dev.send_batch(cdbs=[cdb], odata=[config1])

cdb = struct.pack('>BBB12x', 0xe1, 0x50, 0x1)
dev.send_batch(cdbs=[cdb], odata=[config2])

# Upload the firmware in two parts.
# CDB layout '>BBI' = 0xe3, 0x50 (first part) or 0xd0 (second part), then the payload length as a
# big-endian u32. That makes this CDB 6 bytes, unlike the 15-byte CDBs used elsewhere here.
# NOTE(review): presumably the device/send_batch accepts the shorter CDB — TODO confirm this is intentional.
cdb = struct.pack('>BBI', 0xe3, 0x50, len(part1))
dev.send_batch(cdbs=[cdb], odata=[part1])

cdb = struct.pack('>BBI', 0xe3, 0xd0, len(part2))
dev.send_batch(cdbs=[cdb], odata=[part2])

# Final command with no data payload. CDB layout '>BB13x' = 0xe8, 0x51, then 13 zero pad bytes (15 bytes total).
# NOTE(review): presumably this commits/activates the uploaded firmware — TODO confirm.
cdb = struct.pack('>BB13x', 0xe8, 0x51)
dev.send_batch(cdbs=[cdb])

print("done, you can disconnect the controller!")
